from ril.policy import *
from ril.grid_world.rule import grid_world_rule
from ril.spacewar.rule import space_war_rule
from ril.flappy_bird.rule import flappy_bird_rule
from ril.pong.rule import pong_rule
from ril.mcc.rule import mcc_rule
from ril.cart_pole import cart_pole_net, cart_pole_policy
from ril.grid_world import grid_world_net, grid_world_policy

# Registry of supported tasks.
# Keys are UPPER-CASE task ids, because factory_builder upper-cases the task
# before looking it up.
# Each value is a (net registry, policy factory) pair:
#   - the net registry is indexed by lower-cased agent name;
#   - the policy factory is called with that same lower-cased agent name.
agent_table = {
    # Both CartPole versions share the same networks and policies.
    **dict.fromkeys(('CARTPOLE-V1', 'CARTPOLE-V0'), (cart_pole_net, cart_pole_policy)),
    'GRIDWORLD-V0': (grid_world_net, grid_world_policy),
}


def factory_builder(task: str, agent: str, *, table=None):
    """Return the network entry and the policy built for ``agent`` on ``task``.

    Both names are matched case-insensitively. ``task`` is upper-cased to
    index the task registry. ``agent`` is lower-cased, both to index that
    task's network registry and as the argument to the task's policy factory.

    Args:
        task: Task identifier in any case, e.g. ``'CartPole-v1'``.
        agent: Agent name. Its lower-cased form must be a key of the task's
            network registry.
        table: Optional task registry to search instead of the module-level
            ``agent_table``. Each value must be a ``(net_registry, policy_fn)``
            pair.

    Returns:
        A tuple ``(net_registry[agent], policy_fn(agent))``, where ``agent``
        is lower-cased.

    Raises:
        KeyError: If the task is not registered, or if the agent has no entry
            in the task's network registry.
    """
    registry = agent_table if table is None else table
    task_key = task.upper()
    agent_key = agent.lower()

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, so input validation must not depend on them.
    if task_key not in registry:
        known = ', '.join(sorted(registry))
        raise KeyError(f"unknown task {task!r}; expected one of: {known}")

    net_registry, policy_fn = registry[task_key]
    try:
        net = net_registry[agent_key]
    except KeyError as err:
        raise KeyError(
            f"no network registered for agent {agent!r} on task {task!r}"
        ) from err
    # The network is looked up first, so the policy factory is only called
    # for agents that have a registered network.
    return net, policy_fn(agent_key)
